{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28c7f468-79e9-48e6-9e6f-51dac6116fb0",
   "metadata": {},
   "source": [
    "# Install requirements. %pip (rather than !pip) targets the environment of the running kernel;\n",
    "# the quotes keep the shell from treating the [graphviz] extra as a glob pattern.\n",
    "%pip install falkordb openai \"burr[graphviz]\""
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "5a2c326f-53cc-4eb1-8508-e06584d8051c",
   "metadata": {},
   "source": [
    "# Question & answer notebook\n",
    "\n",
    "This notebook walks you through how to build a Burr application that talks to FalkorDB and OpenAI to answer questions about UFC fights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "39233706-20fd-4043-bc0f-291626664f08",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:45:35.743153Z",
     "start_time": "2024-05-27T16:45:35.716206Z"
    }
   },
   "source": [
    "# import what we need\n",
    "import json\n",
    "from typing import Tuple\n",
    "\n",
    "import openai\n",
    "from burr.core import ApplicationBuilder, State, default, expr, Application\n",
    "from burr.core.action import action\n",
    "from burr.tracking import LocalTrackingClient\n",
    "import uuid\n",
    "from falkordb import FalkorDB\n",
    "from graph_schema import graph_schema\n",
    "import falkordb"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "a2fe5f83-75b4-4d44-a80e-829947b6d240",
   "metadata": {},
   "source": [
    "## Helper functions\n",
    "We first set up some helper functions that we'll use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ef809095-566a-40af-a8a6-cbf72ef21227",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:45:40.907749Z",
     "start_time": "2024-05-27T16:45:40.894703Z"
    }
   },
   "source": [
    "def schema_to_prompt(schema):\n",
    "    \"\"\"Prompt to help tell the LLM what is in the graph DB\"\"\"\n",
    "    prompt = \"The Knowledge graph contains nodes of the following types:\\n\"\n",
    "\n",
    "    for node in schema['nodes']:\n",
    "        lbl = node\n",
    "        node = schema['nodes'][node]\n",
    "        if len(node['attributes']) > 0:\n",
    "            prompt += f\"The {lbl} node type has the following set of attributes:\\n\"\n",
    "            for attr in node['attributes']:\n",
    "                t = node['attributes'][attr]['type']\n",
    "                prompt += f\"The {attr} attribute is of type {t}\\n\"\n",
    "        else:\n",
    "            prompt += f\"The {node} node type has no attributes:\\n\"\n",
    "\n",
    "    prompt += \"In addition the Knowledge graph contains edge of the following types:\\n\"\n",
    "\n",
    "    for edge in schema['edges']:\n",
    "        rel = edge\n",
    "        edge = schema['edges'][edge]\n",
    "        if len(edge['attributes']) > 0:\n",
    "            prompt += f\"The {rel} edge type has the following set of attributes:\\n\"\n",
    "            for attr in edge['attributes']:\n",
    "                t = edge['attributes'][attr]['type']\n",
    "                prompt += f\"The {attr} attribute is of type {t}\\n\"\n",
    "        else:\n",
    "            prompt += f\"The {rel} edge type has no attributes:\\n\"\n",
    "\n",
    "        prompt += f\"The {rel} edge connects the following entities:\\n\"\n",
    "        for conn in edge['connects']:\n",
    "            src = conn[0]\n",
    "            dest = conn[1]\n",
    "            prompt += f\"{src} is connected via {rel} to {dest}, (:{src})-[:{rel}]->(:{dest})\\n\"\n",
    "\n",
    "    return prompt\n",
    "\n",
    "def set_inital_chat_history(schema_prompt: str) -> list[dict]:\n",
    "    \"\"\"Helper to set initial system message\"\"\"\n",
    "    SYSTEM_MESSAGE = \"You are a Cypher expert with access to a directed knowledge graph\\n\"\n",
    "    SYSTEM_MESSAGE += schema_prompt\n",
    "    SYSTEM_MESSAGE += (\"Query the knowledge graph to extract relevant information to help you anwser the users \"\n",
    "                       \"questions, base your answer only on the context retrieved from the knowledge graph, \"\n",
    "                       \"do not use preexisting knowledge.\")\n",
    "    SYSTEM_MESSAGE += (\"For example to find out if two fighters had fought each other e.g. did Conor McGregor \"\n",
    "                       \"every compete against Jose Aldo issue the following query: \"\n",
    "                       \"MATCH (a:Fighter)-[]->(f:Fight)<-[]-(b:Fighter) WHERE a.Name = 'Conor McGregor' AND \"\n",
    "                       \"b.Name = 'Jose Aldo' RETURN a, b\\n\")\n",
    "\n",
    "    messages = [{\"role\": \"system\", \"content\": SYSTEM_MESSAGE}]\n",
    "    return messages"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "9a78e2eb1ab4c1ef",
   "metadata": {},
   "source": [
    "## Tools\n",
    "Here we define the tool that OpenAI will use, along with its schema, which is passed to the model to describe the tool.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7134877d85934e9d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:45:54.998999Z",
     "start_time": "2024-05-27T16:45:54.991226Z"
    }
   },
   "source": [
    "def run_cypher_query(graph, query):\n",
    "    \"\"\"What executes a query against falkorDB\"\"\"\n",
    "    try:\n",
    "        results = graph.ro_query(query).result_set\n",
    "    except:\n",
    "        results = {\"error\": \"Query failed please try a different variation of this query\"}\n",
    "\n",
    "    if len(results) == 0:\n",
    "        results = {\n",
    "            \"error\": \"The query did not return any data, please make sure you're using the right edge \"\n",
    "                     \"directions and you're following the correct graph schema\"}\n",
    "\n",
    "    return str(results)\n",
    "\n",
    "# description\n",
    "run_cypher_query_tool_description = {\n",
    "    \"type\": \"function\",\n",
    "    \"function\": {\n",
    "        \"name\": \"run_cypher_query\",\n",
    "        \"description\": \"Runs a Cypher query against the knowledge graph\",\n",
    "        \"parameters\": {\n",
    "            \"type\": \"object\",\n",
    "            \"properties\": {\n",
    "                \"query\": {\n",
    "                    \"type\": \"string\",\n",
    "                    \"description\": \"Query to execute\",\n",
    "                },\n",
    "            },\n",
    "            \"required\": [\"query\"],\n",
    "        },\n",
    "    },\n",
    "}"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "8f1987f6af0eb2b6",
   "metadata": {},
   "source": [
    "## Actions\n",
    "Let's now define the actions that our application will make and what they read from & write to with respect to state.\n",
    "\n",
    "We'll define four of them:\n",
    "\n",
    "1. Human converse: This action will take the user's question and store it in the state.\n",
    "2. AI create cypher query: This action will use the user's question to create a cypher query.\n",
    "3. Tool call: This action will execute the cypher query and append the result to the chat history.\n",
    "4. AI response: This action will take the result of the cypher query and create a response."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6c1e5d3a8bc04ec6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:46:20.736830Z",
     "start_time": "2024-05-27T16:46:20.729995Z"
    }
   },
   "source": [
    "@action(reads=[], writes=[\"question\", \"chat_history\"])\n",
    "def human_converse(state: State, user_question: str) -> Tuple[dict, State]:\n",
    "    \"\"\"Store the user's question in state and add it to the chat history as a user message.\n",
    "\n",
    "    Returns the question as the action result together with the updated state.\n",
    "    \"\"\"\n",
    "    user_message = {\"role\": \"user\", \"content\": user_question}\n",
    "    updated_state = state.update(question=user_question).append(chat_history=user_message)\n",
    "    return {\"question\": user_question}, updated_state"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a29aeb9b7b025591",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:46:24.286410Z",
     "start_time": "2024-05-27T16:46:24.277468Z"
    }
   },
   "source": [
    "@action(reads=[\"question\", \"chat_history\"], writes=[\"chat_history\", \"tool_calls\"])\n",
    "def AI_create_cypher_query(state: State, client: openai.Client) -> tuple[dict, State]:\n",
    "    \"\"\"Send the chat history to the model, offering it the Cypher query tool.\n",
    "\n",
    "    The model's message is appended to chat_history. If that message carries tool calls they\n",
    "    are stored under tool_calls; otherwise tool_calls is left as it was.\n",
    "\n",
    "    Returns the message content and token usage as the action result, plus the updated state.\n",
    "    \"\"\"\n",
    "    completion = client.chat.completions.create(\n",
    "        model=\"gpt-4-turbo-preview\",\n",
    "        messages=state[\"chat_history\"],\n",
    "        tools=[run_cypher_query_tool_description],\n",
    "        tool_choice=\"auto\",\n",
    "    )\n",
    "    reply = completion.choices[0].message\n",
    "    updated_state = state.append(chat_history=reply.to_dict())\n",
    "    requested_calls = reply.tool_calls\n",
    "    # No tool calls means the model replied directly instead of requesting a query.\n",
    "    if requested_calls:\n",
    "        updated_state = updated_state.update(tool_calls=requested_calls)\n",
    "    action_result = {\"ai_response\": reply.content, \"usage\": completion.usage.to_dict()}\n",
    "    return action_result, updated_state\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a945e875be76be26",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:46:28.364022Z",
     "start_time": "2024-05-27T16:46:28.347327Z"
    }
   },
   "source": [
    "@action(reads=[\"tool_calls\", \"chat_history\"], writes=[\"tool_calls\", \"chat_history\"])\n",
    "def tool_call(state: State, graph: falkordb.Graph) -> Tuple[dict, State]:\n",
    "    \"\"\"Run every pending tool call against the graph and append each output to the chat history.\n",
    "\n",
    "    Each output is added as a message with role 'tool' carrying the id of the call it answers.\n",
    "    Once all calls have run, tool_calls is reset to an empty list.\n",
    "\n",
    "    Returns the per-call responses as the action result, plus the updated state.\n",
    "    \"\"\"\n",
    "    pending_calls = state.get(\"tool_calls\", [])\n",
    "    updated_state = state\n",
    "    executed = []\n",
    "    for call in pending_calls:\n",
    "        name = call.function.name\n",
    "        # Only one tool is offered to the model, so any other name is unexpected.\n",
    "        assert (name == \"run_cypher_query\")\n",
    "        arguments = json.loads(call.function.arguments)\n",
    "        output = run_cypher_query(graph, arguments.get(\"query\"))\n",
    "        tool_message = {\n",
    "            \"tool_call_id\": call.id,\n",
    "            \"role\": \"tool\",\n",
    "            \"name\": name,\n",
    "            \"content\": output,\n",
    "        }\n",
    "        updated_state = updated_state.append(chat_history=tool_message)\n",
    "        executed.append({\"tool_call_id\": call.id, \"response\": output})\n",
    "    return {\"tool_calls\": executed}, updated_state.update(tool_calls=[])"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "698d1ef5da45b2c6",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:47:18.853441Z",
     "start_time": "2024-05-27T16:47:18.847306Z"
    }
   },
   "source": [
    "@action(reads=[\"chat_history\"], writes=[\"chat_history\"])\n",
    "def AI_generate_response(state: State, client: openai.Client) -> tuple[dict, State]:\n",
    "    \"\"\"Ask the model for a reply to the current chat history and append that reply to it.\n",
    "\n",
    "    No tools are offered on this call. Returns the reply content and token usage as the\n",
    "    action result, plus the updated state.\n",
    "    \"\"\"\n",
    "    # The history sent here includes whatever earlier actions appended, e.g. tool output.\n",
    "    completion = client.chat.completions.create(\n",
    "        model=\"gpt-4-turbo-preview\",\n",
    "        messages=state[\"chat_history\"],\n",
    "    )\n",
    "    reply = completion.choices[0].message\n",
    "    action_result = {\n",
    "        \"ai_response\": reply.content,\n",
    "        \"usage\": completion.usage.to_dict(),\n",
    "    }\n",
    "    return action_result, state.append(chat_history=reply.to_dict())\n"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "9fca50d97f2aca0f",
   "metadata": {},
   "source": [
    "## Define the application\n",
    "This is where we define our application now"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "66411a5074f15a7f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:47:34.643728Z",
     "start_time": "2024-05-27T16:47:34.515076Z"
    }
   },
   "source": [
    "# define our clients / connections / IDs\n",
    "# OpenAI client built with no explicit arguments -- presumably it picks up credentials from the\n",
    "# environment (e.g. OPENAI_API_KEY); TODO confirm for your setup.\n",
    "openai_client = openai.OpenAI()\n",
    "# FalkorDB connection; assumes a server is reachable at localhost:6379.\n",
    "db_client = FalkorDB(host='localhost', port=6379)\n",
    "# name of the graph that is selected from the database in a later cell\n",
    "graph_name = \"UFC\"\n",
    "# random UUID, passed to the ApplicationBuilder as this run's app_id\n",
    "application_run_id = str(uuid.uuid4())"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e4d89ba2d477efb1",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:47:43.535537Z",
     "start_time": "2024-05-27T16:47:38.440468Z"
    }
   },
   "source": [
    "# get a handle to the graph named by graph_name\n",
    "graph = db_client.select_graph(graph_name)\n",
    "# get schema (graph_schema comes from the local graph_schema module; schema_to_prompt reads\n",
    "# the 'nodes' and 'edges' entries of what it returns)\n",
    "schema = graph_schema(graph)\n",
    "# create a prompt from it: a plain-English description of the node and edge types\n",
    "schema_prompt = schema_to_prompt(schema)\n",
    "# set the initial chat history: a one-element list holding the system message\n",
    "base_messages = set_inital_chat_history(schema_prompt)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "dbe987ade3c5b69f",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:49:23.058742Z",
     "start_time": "2024-05-27T16:49:23.052222Z"
    }
   },
   "source": [
    "# Tracker attached to the application below; \"ufc-falkor\" is presumably the project name its\n",
    "# runs are recorded under -- confirm against the Burr tracking docs.\n",
    "tracker = LocalTrackingClient(\"ufc-falkor\")\n",
    "# create graph\n",
    "# Flow: human_converse -> AI_create_cypher_query, then either\n",
    "#   tool_call -> AI_generate_response -> human_converse (when tool_calls is non-empty), or\n",
    "#   straight back to human_converse (the default transition).\n",
    "burr_application = (\n",
    "    ApplicationBuilder()\n",
    "    .with_actions(  # define the actions\n",
    "        AI_create_cypher_query.bind(client=openai_client),\n",
    "        tool_call.bind(graph=graph),\n",
    "        AI_generate_response.bind(client=openai_client),\n",
    "        human_converse\n",
    "    )\n",
    "    .with_transitions(  # define the edges between the actions based on state conditions\n",
    "        (\"human_converse\", \"AI_create_cypher_query\", default),\n",
    "        (\"AI_create_cypher_query\", \"tool_call\", expr(\"len(tool_calls)>0\")),\n",
    "        (\"AI_create_cypher_query\", \"human_converse\", default),\n",
    "        (\"tool_call\", \"AI_generate_response\", default),\n",
    "        (\"AI_generate_response\", \"human_converse\", default)\n",
    "    )\n",
    "    .with_identifiers(app_id=application_run_id)\n",
    "    .with_state(  # initial state\n",
    "       **{\"chat_history\": base_messages, \"tool_calls\": []},\n",
    "    )\n",
    "    .with_entrypoint(\"human_converse\")\n",
    "    .with_tracker(tracker)\n",
    "    .build()\n",
    ")"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5bb8432243e49d35",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:47:47.058569Z",
     "start_time": "2024-05-27T16:47:46.295042Z"
    }
   },
   "source": [
    "# Draw the application's action graph; include_conditions=True asks for the transition\n",
    "# conditions to be shown as well.\n",
    "burr_application.visualize(include_conditions=True)"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "3ea2b29306cbb4d1",
   "metadata": {},
   "source": [
    "## Run the application\n",
    "Here we show how to do a simple loop stopping before `human_converse` each time to get user input before running the graph again.\n",
    "\n",
    "\n",
    "### Viewing a trace of this application in the Burr UI\n",
    "Note: you can view the logs of the conversation in the Burr UI. \n",
    "\n",
    "To see that, in another terminal do:\n",
    "\n",
    "> burr\n",
    "\n",
    "You'll then have the UI running on [http://localhost:7241/](http://localhost:7241/).\n",
    "\n",
    "#### Using the Burr UI in Google Colab\n",
    "To use the UI in Google Colab do the following:\n",
    "\n",
    "1. Run this in a cell\n",
    "```python\n",
    "from google.colab import output\n",
    "output.serve_kernel_port_as_window(7241)\n",
    "```\n",
    "\n",
    "2. Then start the burr UI:\n",
    "```\n",
    "!burr &\n",
    "```\n",
    "3. Click the link in (1) to open a new tab."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "c2c996de4588512c",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-27T16:48:59.800468Z",
     "start_time": "2024-05-27T16:48:02.100003Z"
    }
   },
   "source": [
    "# run it\n",
    "while True:\n",
    "    # this will ask for input:\n",
    "    question = input(\"What can I help you with?\\n\")\n",
    "    # accept \"exit\" regardless of case or surrounding whitespace\n",
    "    if question.strip().lower() == \"exit\":\n",
    "        break\n",
    "    # run the graph until it is about to ask the human again\n",
    "    current_action, _, current_state = burr_application.run(\n",
    "        halt_before=[\"human_converse\"],\n",
    "        inputs={\"user_question\": question},\n",
    "    )\n",
    "    # we'll then see the AI response:\n",
    "    print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n",
    "# Show the final state. Read it from the application rather than from `current_state`, which is\n",
    "# undefined (NameError) when the loop is exited before any question was asked.\n",
    "burr_application.state"
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "05d0b159-3fd8-48b1-9c15-88375a52d60b",
   "metadata": {},
   "source": [
    "With Burr we can continue where we left off easily!\n",
    "\n",
    "So why not run the conversation through some more?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "6366fdb4-5b08-4a8f-8f73-ffca67d2091a",
   "metadata": {},
   "source": [
    "# run it again -- the application keeps its state, so the conversation continues\n",
    "while True:\n",
    "    # this will ask for input:\n",
    "    question = input(\"What can I help you with?\\n\")\n",
    "    # accept \"exit\" regardless of case or surrounding whitespace\n",
    "    if question.strip().lower() == \"exit\":\n",
    "        break\n",
    "    # run the graph until it is about to ask the human again\n",
    "    current_action, _, current_state = burr_application.run(\n",
    "        halt_before=[\"human_converse\"],\n",
    "        inputs={\"user_question\": question},\n",
    "    )\n",
    "    # we'll then see the AI response:\n",
    "    print(f\"AI: {current_state['chat_history'][-1]['content']}\\n\")\n",
    "# Show the final state. Read it from the application rather than from `current_state`, which is\n",
    "# undefined (NameError) if no question has been asked in this kernel session.\n",
    "burr_application.state"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d98761be-7aa2-4378-8367-b084491ee8b1",
   "metadata": {},
   "source": [],
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
